#!/usr/bin/env python
# coding: utf-8

# In[2]:

import random
random.seed(100)

import os
import numpy as np
import scipy as sp
import scipy.stats
import mnist
import ot
import tqdm
import matplotlib.pyplot as plt
import pandas as pd
import cvxpy as cp

from utils import mnist_utilities, opt_utilities

# defining our method -------------------

def baryc_proj(source, target, method):
    """Barycentric projection of ``source`` onto ``target``.

    An optimal-transport plan between the two point clouds (both given
    uniform weights) is computed on the max-normalised cost matrix returned
    by ``ot.dist``.  Every source point is then mapped to the plan-weighted
    average of the target points, i.e. the conditional expectation of the
    target given that source point.

    Parameters
    ----------
    source : array of shape (n1, p)
    target : array of shape (n2, p)
    method : {'emd', 'entropic'}
        'emd' uses the exact solver ``ot.emd``; 'entropic' uses
        ``ot.bregman.sinkhorn_stabilized`` with ``reg = 5e-3``.

    Returns
    -------
    numpy.ndarray of shape (n1, p), dtype float32.

    Raises
    ------
    ValueError
        If ``method`` is not one of the supported names.
    """
    # Fail fast instead of hitting an unbound ``OTplan`` further down.
    if method not in ('emd', 'entropic'):
        raise ValueError(
            f"unknown method {method!r}; expected 'emd' or 'entropic'"
        )

    n1 = source.shape[0]
    n2 = target.shape[0]
    a_ones, b_ones = np.ones((n1,)) / n1, np.ones((n2,)) / n2

    M = ot.dist(source, target)
    M = M.astype('float64')
    # Rescale costs to [0, 1]; skip when every cost is zero to avoid 0/0.
    cost_max = M.max()
    if cost_max > 0:
        M /= cost_max

    if method == 'emd':
        # numItermax is an iteration count, so it must be an int.
        OTplan = ot.emd(a_ones, b_ones, M, numItermax=10_000_000)
    else:
        OTplan = ot.bregman.sinkhorn_stabilized(a_ones, b_ones, M, reg=5*1e-3)

    OTplan = np.asarray(OTplan)

    # Normalise each row into a conditional distribution over target points,
    # then take all conditional expectations with a single matrix product.
    row_mass = OTplan.sum(axis=1, keepdims=True)
    OTmap = (OTplan / row_mass) @ target

    return OTmap.astype('float32')


def DSCreplication(target, controls, method = 'emd', projtype = 'wass'):
    """Replicate ``target`` from ``controls`` via barycentric projections.

    Steps:

    1. For every control, compute ``baryc_proj(target, control, method)``.
    2. Find weights on the simplex (non-negative, summing to one) that
       minimise the squared norm of the weighted sum of
       ``projection - target``, divided by the stabiliser ``S``.
    3. Build the replication from the projections and the weights.

    Parameters
    ----------
    target : array of shape (n, d)
    controls : sequence of 2-D arrays
        Each must have ``d`` columns.
    method : {'emd', 'entropic'}
        Forwarded to ``baryc_proj``.
    projtype : {'eucl', 'wass'}
        'eucl' returns the weighted sum of the projections.  'wass' returns
        ``ot.lp.free_support_barycenter`` of the projections, initialised at
        that weighted sum.

    Returns
    -------
    (weights, projection)
        ``weights`` is the optimised weight vector of length
        ``len(controls)``; ``projection`` is the replication of ``target``.

    Raises
    ------
    ValueError
        If ``projtype`` is unknown or ``controls`` is empty.
    RuntimeError
        If the solver returns no value for the weights.
    """
    # Validate before the expensive transport and optimisation steps.
    if projtype not in ('eucl', 'wass'):
        raise ValueError(
            f"unknown projtype {projtype!r}; expected 'eucl' or 'wass'"
        )

    J = len(controls)
    if J == 0:
        raise ValueError("controls must contain at least one element")

    n = target.shape[0]
    d = target.shape[1]
    S = np.mean(target)*n*d*J # Stabilizer: to ground the optimization objective

    # Barycentric projection of the target onto each control, and the
    # displacement of each projection from the target.
    G_list = [baryc_proj(target, control, method) for control in controls]
    proj_list = [G - target for G in G_list]

    # Obtain optimal weights
    mylambda = cp.Variable(J)

    objective = cp.Minimize(
                    cp.sum_squares(
                    cp.sum([mylambda[j]*proj for j, proj in enumerate(proj_list)], axis = 0))/S
                    )

    constraints = [mylambda >= 0, mylambda <= 1, cp.sum(mylambda) == 1]

    prob = cp.Problem(objective, constraints)
    prob.solve()

    weights = mylambda.value
    if weights is None:
        # The variable has no value when the solve did not succeed.
        raise RuntimeError(
            f"weight optimisation failed (solver status: {prob.status})"
        )

    testproj = sum([a*b for a,b in zip(weights, G_list)])
    print('optimized')

    if projtype == 'eucl':
        return(weights, testproj)

    # projtype == 'wass': every projection has n points with uniform mass.
    measureweights = [ot.unif(n)]*J
    projection = ot.lp.free_support_barycenter(G_list, measureweights, X_init = testproj,
                                               weights = weights)

    return(weights, projection)


def euclideanprojection(target, controls, projtype = 'eucl'):
    """Replicate ``target`` as a simplex-weighted combination of ``controls``.

    Weights on the simplex (non-negative, summing to one) are chosen to
    minimise the squared norm of the weighted sum of ``target - control``,
    divided by the stabiliser ``S``.

    Parameters
    ----------
    target : array of shape (n, d)
    controls : sequence of arrays
        Each must have the same shape as ``target``, since
        ``target - control`` is formed directly.
    projtype : {'eucl', 'wass'}
        'eucl' returns the weighted sum of the controls.  'wass' returns
        ``ot.lp.free_support_barycenter`` of the controls, initialised at
        that weighted sum.

    Returns
    -------
    (weights, projection)
        ``weights`` is the optimised weight vector of length
        ``len(controls)``; ``projection`` is the replication of ``target``.

    Raises
    ------
    ValueError
        If ``projtype`` is unknown or ``controls`` is empty.
    RuntimeError
        If the solver returns no value for the weights.
    """
    # Validate before building and solving the optimisation problem.
    if projtype not in ('eucl', 'wass'):
        raise ValueError(
            f"unknown projtype {projtype!r}; expected 'eucl' or 'wass'"
        )

    J = len(controls)
    if J == 0:
        raise ValueError("controls must contain at least one element")

    n = target.shape[0]
    d = target.shape[1]
    S = np.mean(target)*n*d*J # Stabilizer: to ground the optimization objective

    proj_list = [target - control for control in controls]

    # Obtain optimal weights
    mylambda = cp.Variable(J)

    objective = cp.Minimize(
                    cp.sum_squares(
                    cp.sum([mylambda[j]*proj for j, proj in enumerate(proj_list)], axis = 0))/S)

    constraints = [mylambda >= 0, mylambda <= 1, cp.sum(mylambda) == 1]

    prob = cp.Problem(objective, constraints)
    prob.solve()

    weights = mylambda.value
    if weights is None:
        # The variable has no value when the solve did not succeed.
        raise RuntimeError(
            f"weight optimisation failed (solver status: {prob.status})"
        )

    testproj = sum([a*b for a,b in zip(weights, controls)])
    print('optimized')

    if projtype == 'eucl':
        return(weights, testproj)

    # projtype == 'wass': uniform mass on the n points of every control.
    measureweights = [ot.unif(n)]*J
    projection = ot.lp.free_support_barycenter(controls, measureweights, X_init = testproj,
                                               weights = weights)

    return(weights, projection)



# Load the data: change this to where you want MNIST saved to / where it is saved
mnist.temporary_dir = lambda: '.../mnist'

train_images = mnist.train_images()
train_labels = mnist.train_labels()

# Following Werenski et al.
# ============= GENERATE REF AND TARGS ============= 

# `random.seed` does not touch NumPy's generator, and all sampling in this
# script goes through `np.random`, so seed NumPy explicitly for reproducibility.
np.random.seed(100)

# set up the dataset (noiseless)
digit = 4
nref_digit = 10
ntarg_digit = 500

# Shuffle the images of the chosen digit, then take disjoint slices of the
# same permutation: the first `nref_digit` are references, the next
# `ntarg_digit` are targets.
indices = np.where(train_labels == digit)[0]
perm = np.random.permutation(len(indices))
ref_inds = indices[perm[0:nref_digit]]
targ_inds = indices[perm[nref_digit:nref_digit+ntarg_digit]]

# Normalise every image to unit total mass.  `ref_digits` is stacked into one
# array; `targ_digits` stays a list of 2-D arrays.
ref_digits = np.array(
    [train_images[ri] / train_images[ri].sum() for ri in ref_inds]
)
targ_digits = [train_images[ti] / train_images[ti].sum() for ti in targ_inds]


# run this for occlusion -------

# Binary mask that zeroes rows 10-17 and columns 10-17 of a 28x28 image.
mask = np.ones((28, 28))
mask[10:18, 10:18] = 0

# Occluded references: renormalise the image, apply the mask, then
# renormalise what is left to unit mass.
ref_perts = []
for k in range(nref_digit):
    ref_digits[k] = ref_digits[k] / ref_digits[k].sum()
    ref_pert = ref_digits[k] * mask
    ref_pert = ref_pert / ref_pert.sum()
    ref_perts.append(ref_pert)

# Occluded targets, built the same way.
targ_perts = []
for k in range(ntarg_digit):
    targ_digits[k] = targ_digits[k] / targ_digits[k].sum()
    targ_pert = targ_digits[k] * mask
    targ_pert = targ_pert / targ_pert.sum()
    targ_perts.append(targ_pert)

# References are stacked into a single array; targets stay a list.
ref_perts = np.array(ref_perts)

# ============= RECOVER LAMBDA ============= 

print("Recovering Lambda")
print("  No Entropy")

# no entropy: one row of estimated weights per occluded target
noe_lams = np.zeros((ntarg_digit, nref_digit))
for row in tqdm.tqdm(range(ntarg_digit)):
    # matrix A of inner products between the target and the references
    A = mnist_utilities.inner_products(targ_perts[row], ref_perts)
    # estimate lambda from A
    noe_lams[row, :] = opt_utilities.solve(A)

print("  Some Entropy")

# some entropy
# support: all (row, col) pixel coordinates of a 28x28 grid, row-major
supp = np.array([[r, c] for r in range(28) for c in range(28)])

ent_lams = np.zeros((ntarg_digit, nref_digit))
for row in tqdm.tqdm(range(ntarg_digit)):
    # same estimate, using the entropic inner products on the pixel support
    A = mnist_utilities.entropic_inner_products(
        targ_perts[row], ref_perts, entropy=10, supp=supp
    )
    ent_lams[row, :] = opt_utilities.solve(A)

# ============= RECOVER TARGETS ============= 
print("Recovering Digit")

linspace = np.arange(28)
# Pairwise squared Euclidean distances between pixel coordinates, used as the
# ground cost for the transport cost below.
M = np.asarray(ot.dist(supp, metric='sqeuclidean'), dtype=np.float64)

noe_dists = np.zeros(ntarg_digit)
ent_dists = np.zeros(ntarg_digit)
for idx in tqdm.tqdm(range(ntarg_digit)):
    # Barycenters of the unoccluded references under each set of weights.
    noe_bc = mnist_utilities.barycenter(ref_digits, noe_lams[idx], entropy=0.001)
    ent_bc = mnist_utilities.barycenter(ref_digits, ent_lams[idx], entropy=0.001)

    # Unwrap each image and rescale it to sum to one.
    histograms = []
    for image in (noe_bc, ent_bc, targ_digits[idx]):
        unwrapped = mnist_utilities.unwrap_image(image, [0,28,0,28])
        histograms.append(unwrapped / unwrapped.sum())
    noe_unwrap, ent_unwrap, targ_unwrap = histograms

    # Exact transport cost between each reconstruction and the true target.
    noe_dists[idx] = ot.lp.emd2(noe_unwrap, targ_unwrap, M)
    ent_dists[idx] = ot.lp.emd2(ent_unwrap, targ_unwrap, M)

print(f"  No Entropy {noe_dists.mean()}")
print(f"  Entropy {ent_dists.mean()}")

# pick a random target index and generate a row of Figure 4
i = np.random.randint(0,ntarg_digit)

# our implementation
dsc_weights, dsc_opt = DSCreplication(targ_digits[i], ref_digits)

# linear projection of the flattened target onto the flattened references
proj = opt_utilities.linear_projection(
    targ_digits[i].reshape(-1), ref_digits.reshape((nref_digit,-1))
).reshape((28,28))

# barycenter of the references using the entropic weights
bc = mnist_utilities.barycenter(ref_digits, ent_lams[i], entropy=0.001, threshold=0.001)

# plot results ----------------------------

# Panels, left to right: occluded target, linear projection, entropic
# barycenter, DSC replication, original target.
panels = [targ_perts[i], proj, bc, dsc_opt, targ_digits[i]]
for column, image in enumerate(panels, start=1):
    ax = plt.subplot(1, 5, column)
    ax.imshow(image)
    ax.axis('off')

plt.savefig("experiment.jpg")
plt.show()
